We ensure that the sampling function is explicitly called during the image generation process after obtaining z_mean and z_log_var.¶
Generate different images each time by explicitly calling the sampling function with new random noise.
In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
The process of sampling the latent vector involves drawing a random sample from the distribution defined by z_mean and z_log_var. This is done using the reparameterization trick, which allows gradients to flow through the sampling process during training.
In [2]:
# Sampling function with randomness
def sampling(z_mean, z_log_var):
    """Reparameterization trick: draw z ~ N(z_mean, exp(z_log_var)).

    The sample is expressed as a deterministic transform of the inputs plus
    standard-normal noise (epsilon), so gradients can flow through
    z_mean / z_log_var during training while z stays stochastic.
    """
    batch = tf.shape(z_mean)[0]
    dim = tf.shape(z_mean)[1]
    # tf.random.normal replaces the deprecated tf.keras.backend.random_normal
    # (removed in Keras 3); the sampled distribution is identical.
    epsilon = tf.random.normal(shape=(batch, dim))
    # exp(0.5 * log_var) is the standard deviation.
    return z_mean + tf.exp(0.5 * z_log_var) * epsilon
In [3]:
# VAE Class with custom call method
class VAE(keras.Model):
    """Variational autoencoder wired for add_loss-based training.

    The encoder must return (z_mean, z_log_var, z); the decoder maps z back
    to image space. Reconstruction and KL losses are registered via
    self.add_loss() inside call(), so the model is compiled without an
    external loss function.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)  # modern zero-arg super()
        self.encoder = encoder
        self.decoder = decoder

    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        # Mean per-pixel binary cross-entropy, rescaled by the pixel count
        # (200*200*3) so it is comparable in magnitude to the KL term.
        reconstruction_loss = tf.reduce_mean(
            tf.keras.losses.binary_crossentropy(inputs, reconstructed)
        )
        reconstruction_loss *= 200 * 200 * 3
        # Closed-form KL(N(mu, sigma^2) || N(0, 1)):
        # -0.5 * (1 + log_var - mu^2 - exp(log_var)), averaged here.
        kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
        kl_loss = tf.reduce_mean(kl_loss)
        kl_loss *= -0.5
        self.add_loss(reconstruction_loss + kl_loss)
        return reconstructed
In [ ]:
In [4]:
# Encoder
# Maps a 200x200 RGB image to a 2-D latent distribution (z_mean, z_log_var)
# plus one stochastic sample z drawn via the Lambda-wrapped sampling().
latent_dim = 2
encoder_inputs = keras.Input(shape=(200, 200, 3))
# Two stride-2 convolutions downsample 200 -> 100 -> 50 spatially.
x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
# The Lambda layer calls sampling(z_mean, z_log_var), so z changes each pass.
z = layers.Lambda(lambda args: sampling(*args), output_shape=(latent_dim,), name='z')([z_mean, z_log_var])
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name='encoder')
encoder.summary()
Model: "encoder"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ input_layer (InputLayer) │ (None, 200, 200, 3) │ 0 │ - │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ conv2d (Conv2D) │ (None, 100, 100, 32) │ 896 │ input_layer[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ conv2d_1 (Conv2D) │ (None, 50, 50, 64) │ 18,496 │ conv2d[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ flatten (Flatten) │ (None, 160000) │ 0 │ conv2d_1[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ dense (Dense) │ (None, 16) │ 2,560,016 │ flatten[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ z_mean (Dense) │ (None, 2) │ 34 │ dense[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ z_log_var (Dense) │ (None, 2) │ 34 │ dense[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ z (Lambda) │ (None, 2) │ 0 │ z_mean[0][0], │ │ │ │ │ z_log_var[0][0] │ └───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘
Total params: 2,579,476 (9.84 MB)
Trainable params: 2,579,476 (9.84 MB)
Non-trainable params: 0 (0.00 B)
In [ ]:
In [5]:
# Decoder
# Inverts the encoder path: latent vector -> 50x50x64 feature map -> image.
latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(50 * 50 * 64, activation="relu")(latent_inputs)
x = layers.Reshape((50, 50, 64))(x)
# Two stride-2 transposed convolutions upsample 50 -> 100 -> 200.
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
# Sigmoid keeps outputs in [0, 1], matching the /255-normalized inputs.
decoder_outputs = layers.Conv2DTranspose(3, 3, activation="sigmoid", padding="same")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name='decoder')
decoder.summary()
Model: "decoder"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ │ input_layer_1 (InputLayer) │ (None, 2) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ dense_1 (Dense) │ (None, 160000) │ 480,000 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ reshape (Reshape) │ (None, 50, 50, 64) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ conv2d_transpose (Conv2DTranspose) │ (None, 100, 100, 64) │ 36,928 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ conv2d_transpose_1 (Conv2DTranspose) │ (None, 200, 200, 32) │ 18,464 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ conv2d_transpose_2 (Conv2DTranspose) │ (None, 200, 200, 3) │ 867 │ └──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
Total params: 536,259 (2.05 MB)
Trainable params: 536,259 (2.05 MB)
Non-trainable params: 0 (0.00 B)
In [ ]:
In [6]:
# VAE Model
vae = VAE(encoder, decoder)
# No loss= argument: VAE.call() registers its losses via add_loss().
vae.compile(optimizer='adam')
vae.build(input_shape=(None, 200, 200, 3))
vae.summary()
# Prepare and normalize the image
pic_2 = keras.preprocessing.image.load_img('pic_2.jpeg', target_size=(200, 200))
pic_2 = keras.preprocessing.image.img_to_array(pic_2).astype("float32") / 255
pic_2 = np.expand_dims(pic_2, 0)
# Train the VAE model
# NOTE(review): the model is deliberately (over)fit to one image for 100 epochs.
vae.fit(pic_2, epochs=100, batch_size=1)
Model: "vae"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ │ encoder (Functional) │ ? │ 2,579,476 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ decoder (Functional) │ ? │ 536,259 │ └──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
Total params: 3,115,735 (11.89 MB)
Trainable params: 3,115,735 (11.89 MB)
Non-trainable params: 0 (0.00 B)
Epoch 1/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 2s 2s/step - loss: 83178.2500 Epoch 2/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 54ms/step - loss: 83129.2500 Epoch 3/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - loss: 83552.8594 Epoch 4/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step - loss: 82888.7812 Epoch 5/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - loss: 83003.5703 Epoch 6/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 64ms/step - loss: 82873.5703 Epoch 7/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step - loss: 80328.5781 Epoch 8/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 64ms/step - loss: 103015.8984 Epoch 9/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step - loss: 81020.2734 Epoch 10/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 56ms/step - loss: 82548.5391 Epoch 11/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - loss: 82767.2734 Epoch 12/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step - loss: 82772.3750 Epoch 13/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step - loss: 82759.4844 Epoch 14/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step - loss: 82740.7656 Epoch 15/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 64ms/step - loss: 82719.9766 Epoch 16/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 63ms/step - loss: 82767.0391 Epoch 17/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step - loss: 82705.4688 Epoch 18/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 56ms/step - loss: 82680.0156 Epoch 19/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - loss: 82687.8828 Epoch 20/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 63ms/step - loss: 82682.8281 Epoch 21/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step - loss: 82598.8672 Epoch 22/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 114ms/step - loss: 82637.7812 Epoch 23/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - loss: 82592.1875 Epoch 24/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 53ms/step - loss: 82544.9453 Epoch 25/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 54ms/step - loss: 82512.3828 Epoch 26/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - loss: 82422.1797 Epoch 27/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - loss: 82315.6719 Epoch 28/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - loss: 82401.1406 Epoch 29/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 
61ms/step - loss: 82375.1953 Epoch 30/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step - loss: 82416.9453 Epoch 31/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step - loss: 82308.6953 Epoch 32/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 57ms/step - loss: 82082.1953 Epoch 33/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 56ms/step - loss: 82227.6172 Epoch 34/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 57ms/step - loss: 82238.7031 Epoch 35/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - loss: 81953.7969 Epoch 36/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 57ms/step - loss: 82037.8750 Epoch 37/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 53ms/step - loss: 82018.2031 Epoch 38/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 55ms/step - loss: 82014.2812 Epoch 39/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 57ms/step - loss: 81661.3438 Epoch 40/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step - loss: 80818.4844 Epoch 41/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 63ms/step - loss: 82053.0547 Epoch 42/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 55ms/step - loss: 79494.8359 Epoch 43/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - loss: 80898.1875 Epoch 44/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 63ms/step - loss: 79759.9375 Epoch 45/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - loss: 79260.4375 Epoch 46/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step - loss: 76992.2734 Epoch 47/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - loss: 79651.1875 Epoch 48/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 64ms/step - loss: 81640.4766 Epoch 49/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 69ms/step - loss: 81179.8125 Epoch 50/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 68ms/step - loss: 80421.4844 Epoch 51/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 90ms/step - loss: 79459.7109 Epoch 52/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 56ms/step - loss: 79605.2266 Epoch 53/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 57ms/step - loss: 77276.2812 Epoch 54/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 68ms/step - loss: 76496.4766 Epoch 55/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 80ms/step - loss: 76753.7344 Epoch 56/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 68ms/step - loss: 74851.1719 Epoch 57/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - loss: 77927.4219 Epoch 
58/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 56ms/step - loss: 74561.1562 Epoch 59/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 55ms/step - loss: 111720.9141 Epoch 60/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step - loss: 73512.0859 Epoch 61/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - loss: 76679.8359 Epoch 62/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - loss: 76298.5703 Epoch 63/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - loss: 78089.7812 Epoch 64/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step - loss: 78323.0625 Epoch 65/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 64ms/step - loss: 79091.9688 Epoch 66/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step - loss: 79524.1953 Epoch 67/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 54ms/step - loss: 78808.1172 Epoch 68/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 55ms/step - loss: 77712.7109 Epoch 69/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - loss: 77794.5547 Epoch 70/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 63ms/step - loss: 78462.0469 Epoch 71/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 57ms/step - loss: 77828.1641 Epoch 72/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 52ms/step - loss: 77288.6953 Epoch 73/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - loss: 76320.1953 Epoch 74/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 55ms/step - loss: 76509.0156 Epoch 75/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - loss: 75829.4219 Epoch 76/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 74ms/step - loss: 74081.6094 Epoch 77/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step - loss: 75310.8984 Epoch 78/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step - loss: 70599.2656 Epoch 79/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 56ms/step - loss: 69802.5703 Epoch 80/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 53ms/step - loss: 86031.2734 Epoch 81/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - loss: 68547.1094 Epoch 82/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 56ms/step - loss: 71655.6562 Epoch 83/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 55ms/step - loss: 74415.7734 Epoch 84/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 53ms/step - loss: 72088.7812 Epoch 85/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 58ms/step - loss: 71316.7969 Epoch 86/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 
86ms/step - loss: 72163.7656 Epoch 87/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 113ms/step - loss: 71978.0234 Epoch 88/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 63ms/step - loss: 71232.8516 Epoch 89/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 75ms/step - loss: 71490.9062 Epoch 90/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 54ms/step - loss: 71052.8516 Epoch 91/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step - loss: 70762.4375 Epoch 92/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - loss: 70316.2109 Epoch 93/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step - loss: 71599.3281 Epoch 94/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 53ms/step - loss: 68455.8125 Epoch 95/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 56ms/step - loss: 70260.8125 Epoch 96/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 53ms/step - loss: 68871.9062 Epoch 97/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 55ms/step - loss: 69570.7812 Epoch 98/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - loss: 67728.8359 Epoch 99/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 55ms/step - loss: 67470.1953 Epoch 100/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step - loss: 66914.5156
Out[6]:
<keras.src.callbacks.history.History at 0x2a770907f10>
In [ ]:
In [7]:
# Show plot and save image function
def show_and_save_plot(image, save_path):
    """Render *image* (singleton axes dropped), save it to *save_path*, then display it."""
    plt.imshow(image.squeeze())
    plt.axis('off')
    # Save before show: show() clears the current figure in some backends.
    plt.savefig(save_path)
    plt.show()
for i in range(10):
    z_mean, z_log_var, _ = vae.encoder.predict(pic_2)
    # Introduce randomness: sampling() draws fresh epsilon, so each pass
    # yields a different latent vector and hence a different image.
    encoded_imgs = sampling(z_mean, z_log_var).numpy()
    decoded_imgs = vae.decoder.predict(encoded_imgs)
    save_path = f'generated_image_{i + 1}.png'
    show_and_save_plot(decoded_imgs, save_path)
print("Images have been saved and displayed successfully.")
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 86ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 60ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 21ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 35ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 35ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 25ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 21ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 21ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 21ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 37ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 35ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 25ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 21ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 21ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 22ms/step
Images have been saved and displayed successfully.
In [ ]:
In [ ]:
What output do you see? Explain why¶
For the code above, the resulting 10 images show much more variation than those generated from pic_1¶
In [8]:
# Encode the training image once and inspect the learned posterior parameters.
z_mean, z_log_var, _ = vae.encoder.predict(pic_2)
print("z_mean:", z_mean)
print("z_log_var:", z_log_var)
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 23ms/step z_mean: [[-0.8666203 -0.60099775]] z_log_var: [[0.26988217 0.63260573]]
I still want to try latent_dim = 10¶
In [9]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt
# Sampling function with randomness
def sampling(z_mean, z_log_var):
    """Reparameterization trick: draw z ~ N(z_mean, exp(z_log_var)).

    Deterministic transform of the inputs plus standard-normal noise, so
    gradients flow through z_mean / z_log_var while z stays stochastic.
    """
    batch = tf.shape(z_mean)[0]
    dim = tf.shape(z_mean)[1]
    # tf.random.normal replaces the deprecated tf.keras.backend.random_normal
    # (removed in Keras 3); the sampled distribution is identical.
    epsilon = tf.random.normal(shape=(batch, dim))
    # exp(0.5 * log_var) is the standard deviation.
    return z_mean + tf.exp(0.5 * z_log_var) * epsilon
# VAE Class with custom call method
class VAE(keras.Model):
    """Variational autoencoder using an MSE reconstruction loss.

    The encoder must return (z_mean, z_log_var, z); losses are registered
    through self.add_loss() inside call(), so compile() takes no loss=.
    """

    def __init__(self, encoder, decoder, **kwargs):
        super().__init__(**kwargs)  # modern zero-arg super()
        self.encoder = encoder
        self.decoder = decoder
        self.mse = tf.keras.losses.MeanSquaredError()

    def call(self, inputs):
        z_mean, z_log_var, z = self.encoder(inputs)
        reconstructed = self.decoder(z)
        # MeanSquaredError already returns a scalar mean, so the extra
        # tf.reduce_mean wrapper in the original was a no-op and is dropped.
        reconstruction_loss = self.mse(inputs, reconstructed)
        # Rescale by pixel count so it balances the KL term in magnitude.
        reconstruction_loss *= 200 * 200 * 3
        # Closed-form KL(N(mu, sigma^2) || N(0, 1)).
        kl_loss = 1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var)
        kl_loss = tf.reduce_mean(kl_loss)
        kl_loss *= -0.5
        self.add_loss(reconstruction_loss + kl_loss)
        return reconstructed
# Encoder
# Same architecture as before, but with a 10-D latent space.
latent_dim = 10 # Increased latent dimension
encoder_inputs = keras.Input(shape=(200, 200, 3))
# Two stride-2 convolutions downsample 200 -> 100 -> 50 spatially.
x = layers.Conv2D(32, 3, activation="relu", strides=2, padding="same")(encoder_inputs)
x = layers.Conv2D(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Flatten()(x)
x = layers.Dense(16, activation="relu")(x)
z_mean = layers.Dense(latent_dim, name="z_mean")(x)
z_log_var = layers.Dense(latent_dim, name="z_log_var")(x)
# Lambda layer draws a fresh stochastic sample z on every forward pass.
z = layers.Lambda(lambda args: sampling(*args), output_shape=(latent_dim,), name='z')([z_mean, z_log_var])
encoder = keras.Model(encoder_inputs, [z_mean, z_log_var, z], name='encoder')
encoder.summary()
# Decoder
# Inverts the encoder: 10-D latent vector -> 50x50x64 feature map -> image.
latent_inputs = keras.Input(shape=(latent_dim,))
x = layers.Dense(50 * 50 * 64, activation="relu")(latent_inputs)
x = layers.Reshape((50, 50, 64))(x)
x = layers.Conv2DTranspose(64, 3, activation="relu", strides=2, padding="same")(x)
x = layers.Conv2DTranspose(32, 3, activation="relu", strides=2, padding="same")(x)
# Sigmoid keeps outputs in [0, 1], matching the /255-normalized inputs.
decoder_outputs = layers.Conv2DTranspose(3, 3, activation="sigmoid", padding="same")(x)
decoder = keras.Model(latent_inputs, decoder_outputs, name='decoder')
decoder.summary()
# VAE Model
vae = VAE(encoder, decoder)
# No loss= argument: VAE.call() registers its losses via add_loss().
vae.compile(optimizer='adam')
vae.build(input_shape=(None, 200, 200, 3))
vae.summary()
# Prepare and normalize the image
pic_2 = keras.preprocessing.image.load_img('pic_2.jpeg', target_size=(200, 200))
pic_2 = keras.preprocessing.image.img_to_array(pic_2).astype("float32") / 255
pic_2 = np.expand_dims(pic_2, 0)
# Train the VAE model
# NOTE(review): again fit to a single image for 100 epochs.
vae.fit(pic_2, epochs=100, batch_size=1)
# Show plot and save image function
def show_and_save_plot(image, save_path):
    """Render *image* (singleton axes dropped), save it to *save_path*, then display it."""
    plt.imshow(image.squeeze())
    plt.axis('off')
    # Save first: show() clears the current figure in some backends.
    plt.savefig(save_path)
    plt.show()
# Run the encoder-decoder multiple times to generate different copies
for i in range(10):
    z_mean, z_log_var, _ = vae.encoder.predict(pic_2)
    # Introduce randomness: sampling() draws fresh epsilon each call, so the
    # latent vector — and therefore the decoded image — differs per iteration.
    encoded_imgs = sampling(z_mean, z_log_var).numpy()
    decoded_imgs = vae.decoder.predict(encoded_imgs)
    save_path = f'generated_image_{i + 1}.png'
    show_and_save_plot(decoded_imgs, save_path)
print("Images have been saved and displayed successfully.")
Model: "encoder"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ Connected to ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━┩ │ input_layer_2 (InputLayer) │ (None, 200, 200, 3) │ 0 │ - │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ conv2d_2 (Conv2D) │ (None, 100, 100, 32) │ 896 │ input_layer_2[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ conv2d_3 (Conv2D) │ (None, 50, 50, 64) │ 18,496 │ conv2d_2[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ flatten_1 (Flatten) │ (None, 160000) │ 0 │ conv2d_3[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ dense_2 (Dense) │ (None, 16) │ 2,560,016 │ flatten_1[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ z_mean (Dense) │ (None, 10) │ 170 │ dense_2[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ z_log_var (Dense) │ (None, 10) │ 170 │ dense_2[0][0] │ ├───────────────────────────────┼───────────────────────────┼─────────────────┼────────────────────────────┤ │ z (Lambda) │ (None, 10) │ 0 │ z_mean[0][0], │ │ │ │ │ z_log_var[0][0] │ └───────────────────────────────┴───────────────────────────┴─────────────────┴────────────────────────────┘
Total params: 2,579,748 (9.84 MB)
Trainable params: 2,579,748 (9.84 MB)
Non-trainable params: 0 (0.00 B)
Model: "decoder"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ │ input_layer_3 (InputLayer) │ (None, 10) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ dense_3 (Dense) │ (None, 160000) │ 1,760,000 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ reshape_1 (Reshape) │ (None, 50, 50, 64) │ 0 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ conv2d_transpose_3 (Conv2DTranspose) │ (None, 100, 100, 64) │ 36,928 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ conv2d_transpose_4 (Conv2DTranspose) │ (None, 200, 200, 32) │ 18,464 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ conv2d_transpose_5 (Conv2DTranspose) │ (None, 200, 200, 3) │ 867 │ └──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
Total params: 1,816,259 (6.93 MB)
Trainable params: 1,816,259 (6.93 MB)
Non-trainable params: 0 (0.00 B)
Model: "vae_1"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┓ ┃ Layer (type) ┃ Output Shape ┃ Param # ┃ ┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━┩ │ encoder (Functional) │ ? │ 2,579,748 │ ├──────────────────────────────────────┼─────────────────────────────┼─────────────────┤ │ decoder (Functional) │ ? │ 1,816,259 │ └──────────────────────────────────────┴─────────────────────────────┴─────────────────┘
Total params: 4,396,007 (16.77 MB)
Trainable params: 4,396,007 (16.77 MB)
Non-trainable params: 0 (0.00 B)
Epoch 1/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 2s 2s/step - loss: 9969.8330 Epoch 2/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step - loss: 9952.2275 Epoch 3/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - loss: 9905.1992 Epoch 4/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - loss: 9799.8447 Epoch 5/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - loss: 9443.6328 Epoch 6/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 63ms/step - loss: 8943.6104 Epoch 7/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 70ms/step - loss: 8196.6943 Epoch 8/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - loss: 6990.1729 Epoch 9/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - loss: 6634.2773 Epoch 10/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 63ms/step - loss: 6583.5356 Epoch 11/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - loss: 5908.4053 Epoch 12/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 71ms/step - loss: 5622.7407 Epoch 13/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 76ms/step - loss: 5539.8311 Epoch 14/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 76ms/step - loss: 5298.5249 Epoch 15/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 78ms/step - loss: 5238.3037 Epoch 16/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 70ms/step - loss: 5181.6914 Epoch 17/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 63ms/step - loss: 5127.2827 Epoch 18/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - loss: 5079.2339 Epoch 19/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step - loss: 5033.1089 Epoch 20/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 72ms/step - loss: 4981.9697 Epoch 21/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 73ms/step - loss: 4940.1826 Epoch 22/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 116ms/step - loss: 4901.8281 Epoch 23/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 85ms/step - loss: 4861.6597 Epoch 24/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 74ms/step - loss: 4831.6338 Epoch 25/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 74ms/step - loss: 4815.4307 Epoch 26/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 73ms/step - loss: 4781.6743 Epoch 27/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 73ms/step - loss: 4777.3623 Epoch 28/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step - loss: 4735.8325 Epoch 29/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step - loss: 4715.5664 Epoch 
30/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 69ms/step - loss: 4692.4648 Epoch 31/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 68ms/step - loss: 4653.4199 Epoch 32/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 69ms/step - loss: 4624.4077 Epoch 33/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 71ms/step - loss: 4687.2280 Epoch 34/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step - loss: 4562.8481 Epoch 35/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 64ms/step - loss: 4552.0024 Epoch 36/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step - loss: 4529.4150 Epoch 37/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 68ms/step - loss: 4474.5405 Epoch 38/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - loss: 4438.9492 Epoch 39/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 80ms/step - loss: 4424.4653 Epoch 40/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 76ms/step - loss: 4372.9590 Epoch 41/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 71ms/step - loss: 4338.4277 Epoch 42/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - loss: 4306.4731 Epoch 43/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step - loss: 4260.6147 Epoch 44/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 69ms/step - loss: 4218.7573 Epoch 45/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 71ms/step - loss: 4180.4497 Epoch 46/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step - loss: 4124.2217 Epoch 47/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step - loss: 4065.3379 Epoch 48/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 68ms/step - loss: 4011.7961 Epoch 49/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 71ms/step - loss: 3945.3098 Epoch 50/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step - loss: 3866.5088 Epoch 51/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step - loss: 3787.2458 Epoch 52/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 70ms/step - loss: 3693.9219 Epoch 53/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step - loss: 3578.2249 Epoch 54/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step - loss: 3462.9343 Epoch 55/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 73ms/step - loss: 3297.0491 Epoch 56/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 79ms/step - loss: 3129.5481 Epoch 57/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 69ms/step - loss: 2919.8857 Epoch 58/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 77ms/step - loss: 2685.6443 
Epoch 59/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 71ms/step - loss: 2456.9216 Epoch 60/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 72ms/step - loss: 2239.6851 Epoch 61/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 69ms/step - loss: 2009.8368 Epoch 62/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step - loss: 1831.4513 Epoch 63/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 71ms/step - loss: 1681.6642 Epoch 64/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 64ms/step - loss: 1522.0482 Epoch 65/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step - loss: 1371.6166 Epoch 66/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step - loss: 1306.6349 Epoch 67/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - loss: 1223.9176 Epoch 68/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 69ms/step - loss: 1108.3302 Epoch 69/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - loss: 1037.7194 Epoch 70/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 64ms/step - loss: 983.1470 Epoch 71/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step - loss: 947.4807 Epoch 72/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 70ms/step - loss: 909.9151 Epoch 73/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - loss: 873.0816 Epoch 74/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 72ms/step - loss: 835.9836 Epoch 75/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 74ms/step - loss: 871.8846 Epoch 76/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 73ms/step - loss: 799.4135 Epoch 77/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 67ms/step - loss: 785.8033 Epoch 78/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step - loss: 738.4094 Epoch 79/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step - loss: 736.1783 Epoch 80/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 106ms/step - loss: 706.0345 Epoch 81/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 74ms/step - loss: 696.6925 Epoch 82/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 72ms/step - loss: 672.9388 Epoch 83/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 69ms/step - loss: 662.8980 Epoch 84/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 70ms/step - loss: 646.0948 Epoch 85/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 68ms/step - loss: 635.3652 Epoch 86/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - loss: 625.7963 Epoch 87/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step - loss: 620.7348 Epoch 88/100 
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 63ms/step - loss: 607.3733 Epoch 89/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step - loss: 595.3042 Epoch 90/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 75ms/step - loss: 597.9349 Epoch 91/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 68ms/step - loss: 579.2706 Epoch 92/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step - loss: 615.9193 Epoch 93/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 61ms/step - loss: 569.0283 Epoch 94/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 69ms/step - loss: 642.7772 Epoch 95/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 68ms/step - loss: 555.5855 Epoch 96/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step - loss: 609.1636 Epoch 97/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step - loss: 546.4082 Epoch 98/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 66ms/step - loss: 685.6516 Epoch 99/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 63ms/step - loss: 538.1050 Epoch 100/100 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 62ms/step - loss: 602.5829 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 82ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 59ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 20ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 17ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 21ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 31ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 21ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 21ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 21ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 18ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 33ms/step
1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 19ms/step 1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 25ms/step
Images have been saved and displayed successfully.
In [ ]:
In [ ]:
In [ ]:
In [ ]:
In [ ]:
In [ ]: